"""Models Configuration Parameters."""

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

from dbgpt.core.interface.parameter import (
    BaseServerParameters,
    EmbeddingDeployModelParameters,
    LLMDeployModelParameters,
    RerankerDeployModelParameters,
)
from dbgpt.datasource.parameter import BaseDatasourceParameters
from dbgpt.util.configure.manager import RegisterParameters
from dbgpt.util.i18n_utils import _
from dbgpt.util.parameter_utils import BaseParameters


class WorkerType(str, Enum):
    """Kinds of model workers.

    A worker is addressed by a *worker key* of the form
    ``"<worker_name>@<worker_type>"``; see :meth:`to_worker_key` and
    :meth:`parse_worker_key`.
    """

    LLM = "llm"
    TEXT2VEC = "text2vec"
    RERANKER = "reranker"

    @staticmethod
    def values():
        """Return the string values of all worker types, in definition order."""
        return [item.value for item in WorkerType]

    @staticmethod
    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
        """Generate worker key from worker name and worker type

        Args:
            worker_name (str): Worker name(eg., chatglm2-6b)
            worker_type (Union[str, "WorkerType"]):
                Worker type(eg., 'llm', or [`WorkerType.LLM`])

        Returns:
            str: Generated worker key

        Raises:
            ValueError: If ``worker_name`` contains ``'@'``, which is reserved as
                the separator between name and type.
        """
        if "@" in worker_name:
            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
        if isinstance(worker_type, WorkerType):
            # Use the plain string value so the key format is stable.
            worker_type = worker_type.value
        return f"{worker_name}@{worker_type}"

    @staticmethod
    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
        """Parse worker name and worker type from worker key

        Args:
            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]

        Returns:
            Tuple[str, str]: Worker name and worker type

        Raises:
            ValueError: If ``worker_key`` does not contain the ``'@'`` separator.
        """
        # Worker names may not contain '@' (enforced by `to_worker_key`), so the
        # first '@' is always the separator; this guarantees a 2-tuple.
        worker_name, sep, worker_type = worker_key.partition("@")
        if not sep:
            raise ValueError(
                f"Invalid worker key '{worker_key}', expected format "
                "'<worker_name>@<worker_type>'"
            )
        return worker_name, worker_type

    @classmethod
    def from_str(cls, value: str) -> "WorkerType":
        """Convert a string to an Enum value.

        Raises:
            ValueError: If ``value`` is not one of :meth:`values`.
        """
        try:
            return cls(value)
        except ValueError:
            # The new message already lists the valid values; drop the
            # redundant enum-lookup context from the traceback.
            raise ValueError(
                f"Invalid value '{value}' for {cls.__name__}. "
                f"Valid values are {cls.values()}"
            ) from None


@dataclass
class BaseModelRegistryParameters(BaseParameters, RegisterParameters):
    """Base class for model registry parameters.

    Construction from raw configuration goes through :meth:`_from_dict_`, which
    produces a :class:`DBModelRegistryParameters` when a ``database`` section is
    present.
    """

    # Placeholder type name; presumably concrete registries are resolved via
    # `_from_dict_` rather than by this value — verify against RegisterParameters.
    __type__ = "___model_registry_placeholder___"

    __cfg_type__ = "service"

    @classmethod
    def _from_dict_(
        cls, data: Dict, prepare_data_func, converter
    ) -> Optional["BaseModelRegistryParameters"]:
        """Build registry parameters from a raw configuration mapping.

        Args:
            data: Raw registry configuration; only its ``"database"`` entry is read.
            prepare_data_func: Called as ``prepare_data_func(BaseDatasourceParameters,
                section)`` to produce the data passed on to ``converter``.
            converter: Called as ``converter(prepared, BaseDatasourceParameters)``
                to build the datasource parameters object.

        Returns:
            A ``DBModelRegistryParameters`` wrapping the converted datasource when
            ``data["database"]`` is truthy, otherwise ``None``.
        """
        database_section = data.get("database", None)
        if not database_section:
            return None
        prepared = prepare_data_func(BaseDatasourceParameters, database_section)
        # Carry the datasource type over from the raw section (indexed directly,
        # so a section without "type" fails here).
        prepared["type"] = database_section["type"]
        datasource = converter(prepared, BaseDatasourceParameters)
        return DBModelRegistryParameters(database=datasource)


@dataclass
class DBModelRegistryParameters(BaseModelRegistryParameters):
    """Model registry parameters backed by a database.

    ``database`` holds the datasource configuration for the registry and is
    ``None`` when none is given.
    """

    database: Optional[BaseDatasourceParameters] = field(
        default=None,
        metadata=dict(help=_("Database configuration for model registry")),
    )


@dataclass
class ModelControllerParameters(BaseServerParameters):
    """Deployment parameters for the model controller.

    Besides the listen port, this configures the model registry and the
    heartbeat checks: a worker that does not respond within
    ``heartbeat_timeout_secs`` is set unhealthy.
    """

    port: Optional[int] = field(
        default=8000,
        metadata=dict(help=_("Model Controller deploy port")),
    )
    # Leaving this unset selects the embedded registry.
    registry: Optional[BaseModelRegistryParameters] = field(
        default=None,
        metadata=dict(
            help=_("Model registry configuration. If None, use embedded registry")
        ),
    )

    heartbeat_interval_secs: Optional[int] = field(
        default=20,
        metadata=dict(help=_("The interval for checking heartbeats (seconds)")),
    )
    heartbeat_timeout_secs: Optional[int] = field(
        default=60,
        metadata=dict(
            help=_(
                "The timeout for checking heartbeats (seconds), it will be set "
                "unhealthy if the worker is not responding in this time"
            )
        ),
    )


@dataclass
class ModelAPIServerParameters(BaseServerParameters):
    """Deployment parameters for the model API server.

    The API server listens on ``port`` and connects to the model controller at
    ``controller_addr``.
    """

    port: Optional[int] = field(
        default=8100,
        metadata=dict(help=_("Model API server deploy port")),
    )
    controller_addr: Optional[str] = field(
        default="http://127.0.0.1:8000",
        metadata=dict(help=_("The Model controller address to connect")),
    )

    # Comma separated string, not a Python list.
    api_keys: Optional[str] = field(
        default=None,
        metadata=dict(help=_("Optional list of comma separated API keys")),
    )
    embedding_batch_size: Optional[int] = field(
        default=None,
        metadata=dict(help=_("Embedding batch size")),
    )
    ignore_stop_exceeds_error: Optional[bool] = field(
        default=False,
        metadata=dict(help=_("Ignore exceeds stop words error")),
    )


@dataclass
class ModelWorkerParameters(BaseServerParameters):
    """Deployment parameters for a model worker.

    Covers the worker's type and implementation class, its listen port, and
    how (and whether) it registers with and sends heartbeats to the model
    controller.
    """

    # Restricted to the values of `WorkerType`.
    worker_type: Optional[str] = field(
        default=None,
        metadata=dict(valid_values=WorkerType.values(), help=_("Worker type")),
    )
    worker_class: Optional[str] = field(
        default=None,
        metadata=dict(
            help=_("Model worker class, dbgpt.model.cluster.DefaultModelWorker")
        ),
    )

    port: Optional[int] = field(
        default=8001,
        metadata=dict(help=_("Model worker deploy port")),
    )
    standalone: Optional[bool] = field(
        default=False,
        metadata=dict(
            help=_("Standalone mode. If True, embedded Run ModelController")
        ),
    )
    register: Optional[bool] = field(
        default=True,
        metadata=dict(help=_("Register current worker to model controller")),
    )
    # None means the registration address is determined automatically.
    worker_register_host: Optional[str] = field(
        default=None,
        metadata=dict(
            help=_(
                "The ip address of current worker to register to ModelController. "
                "If None, the address is automatically determined"
            )
        ),
    )
    controller_addr: Optional[str] = field(
        default=None,
        metadata=dict(help=_("The Model controller address to register")),
    )
    send_heartbeat: Optional[bool] = field(
        default=True,
        metadata=dict(help=_("Send heartbeat to model controller")),
    )
    heartbeat_interval: Optional[int] = field(
        default=20,
        metadata=dict(help=_("The interval for sending heartbeats (seconds)")),
    )


@dataclass
class ModelServiceConfig(BaseParameters):
    """Model service configuration.

    Groups the settings of the three model-serving components: the model
    worker, the model API server and the model controller. Each section
    defaults to a fresh instance of its own parameter class.
    """

    __cfg_type__ = "service"

    worker: ModelWorkerParameters = field(
        default_factory=ModelWorkerParameters,
        metadata={"help": _("Model worker configuration")},
    )
    # The default must come from the API server's own parameter class. Using
    # ModelControllerParameters here gave the `api` section a controller config
    # (port 8000, no controller_addr / api_keys fields).
    api: ModelAPIServerParameters = field(
        default_factory=ModelAPIServerParameters, metadata={"help": _("Model API")}
    )
    controller: ModelControllerParameters = field(
        default_factory=ModelControllerParameters,
        metadata={"help": _("Model controller")},
    )


@dataclass
class ModelsDeployParameters(BaseParameters):
    """Which LLM, embedding and reranker models to deploy.

    Each ``default_*`` field names the model to use when several models of
    that kind are configured. A ``default_*`` left empty is filled in by
    ``__post_init__`` with the ``name`` of the first entry in the matching list.
    """

    __cfg_type__ = "service"
    default_llm: Optional[str] = field(
        default=None,
        metadata=dict(
            help=_(
                "Default LLM model name, used to specify which model to use when you "
                "have multiple LLMs"
            ),
        ),
    )
    default_embedding: Optional[str] = field(
        default=None,
        metadata=dict(
            help=_(
                "Default embedding model name, used to specify which model to use when "
                "you have multiple embedding models"
            ),
        ),
    )
    default_reranker: Optional[str] = field(
        default=None,
        metadata=dict(
            help=_(
                "Default reranker model name, used to specify which model to use when "
                "you have multiple reranker models"
            ),
        ),
    )
    llms: List[LLMDeployModelParameters] = field(
        default_factory=list,
        metadata=dict(
            help=_(
                "LLM model deploy configuration. If you deploy in cluster mode, you "
                "just deploy one model."
            )
        ),
    )
    embeddings: List[EmbeddingDeployModelParameters] = field(
        default_factory=list,
        metadata=dict(
            help=_(
                "Embedding model deploy configuration. If you deploy in cluster "
                "mode, you just deploy one model."
            )
        ),
    )
    rerankers: List[RerankerDeployModelParameters] = field(
        default_factory=list,
        metadata=dict(
            help=_(
                "Reranker model deploy configuration. If you deploy in cluster "
                "mode, you just deploy one model."
            )
        ),
    )

    def __post_init__(self):
        """Default each empty ``default_*`` name to the first deployed model."""
        # Gather every configured name first (reads `.name` on all entries).
        llm_names = [model.name for model in self.llms]
        embedding_names = [model.name for model in self.embeddings]
        reranker_names = [model.name for model in self.rerankers]
        # Only fill a default that is empty and has a candidate available.
        if llm_names and not self.default_llm:
            self.default_llm = llm_names[0]
        if embedding_names and not self.default_embedding:
            self.default_embedding = embedding_names[0]
        if reranker_names and not self.default_reranker:
            self.default_reranker = reranker_names[0]
